# Two-sample test: compare estimated connectivity matrices (B) between two groups of networks
source("R/data_gen.R")
source("R/competing_methods.R")
source("tests/utils.R")
library(parallel)
library(ggplot2)
library(dplyr)
library(Matrix)
library(RcppHungarian)

#' Compare block-connectivity estimates between two groups of networks
#'
#' Simulates networks with `gen_rand_nsbm()` and clusters the nodes of all
#' networks jointly after aligning their eigen-representations to the first
#' network. It then pools the node labels to estimate a connectivity matrix
#' B per network group. Finally it clusters the networks themselves, treating
#' group labels as unknown, by applying `spec_clust()` to a pairwise
#' B-distance matrix.
#'
#' @param seed Seed passed to `set.seed()` (this resets the global RNG state).
#' @return A named list:
#'   * `diff_tru`: `norm(eta[[1]] - eta[[2]])` (default one-norm) of the
#'     simulator's `eta` matrices;
#'   * `diff_est`: the same norm of the difference of the two estimated
#'     group B matrices;
#'   * `N`: `table(xi_tru[[1]])`, label counts in the first network;
#'   * `nmi_clust`: `nett::compute_mutual_info()` between the estimated and
#'     true network-group labels;
#'   * `nmi_aligned`: `hsbm::get_slice_nmi()` of the joint aligned clustering;
#'   * `nmi_sc`: the same score for per-network `spec_net_clust()` labels;
#'   * `B1_est`, `B2_est`: the estimated B matrices of the first two groups.
get_diff_B <- function(seed = 123) {
  set.seed(seed)
  # Simulation parameters
  n <- 200       # rows per network representation (one per node)
  K <- 2         # number of network groups
  L <- 5         # number of node clusters used by k-means
  J <- 48        # number of networks
  lambda <- 50
  zeta <- 0.3
  gam <- 0.1

  out <- gen_rand_nsbm(n = n, K = K, L = L, J = J, lambda = lambda,
                       gam = gam, zeta = zeta, sort_z = TRUE)
  A <- out$A
  z_tru <- out$z
  xi_tru <- out$xi
  eta <- out$eta

  # Align every network's eigen-representation to the first network's basis
  Xlist <- lapply(A, function(As) get_eig_repr(As, L))
  Ut_list <- recover_Ut_list(Xlist)
  Xt_list <- lapply(seq_along(Ut_list), function(j) {
    Rt <- align_ortho_mats(Ut_list[[1]], Ut_list[[j]])
    Xlist[[j]] %*% Rt
  })

  # Joint k-means on the stacked aligned rows, then split back per network
  kclust <- kmeans(do.call(rbind, Xt_list), L, nstart = 25)
  xih_vec <- kclust$cluster
  xih <- split(xih_vec, ceiling(seq_along(xih_vec) / n))

  # Average NMI of the joint labels, and of spectral clustering applied to
  # each network individually, for comparison (both are returned)
  nmi_aligned <- hsbm::get_slice_nmi(xih, xi_tru)
  sc <- spec_net_clust(A, K, L)
  nmi_sc <- hsbm::get_slice_nmi(sc$xi, xi_tru)

  # Pooled B estimate over the networks in `idx`: summed block sums divided
  # by summed pair counts (n_k * n_l off the diagonal, n_k^2 - n_k on it).
  estimate_group_B <- function(idx) {
    Bsum <- Reduce(`+`, lapply(idx, function(j) {
      compute_block_sums(A[[j]], xih[[j]])
    }))
    ns <- Reduce(`+`, lapply(idx, function(j) {
      nsj <- table(xih[[j]])
      # nrow= stops diag() from returning an identity matrix of size nsj
      # when only one cluster is present in network j
      outer(nsj, nsj) - diag(nsj, nrow = length(nsj))
    }))
    Bsum / ns
  }

  # Network indices per true group; does not depend on z_tru being sorted
  group_idx <- split(seq_len(J), z_tru)
  B_est <- lapply(group_idx, estimate_group_B)

  # Network clustering with unknown group labels: for every pair of
  # networks, jointly cluster their aligned representations, estimate B for
  # each, and use the norm of the difference as a distance.
  # The diagonal (j == i) is also visited; it consumes RNG draws, so skipping
  # it would change the results obtained for a given seed.
  dist_mat <- matrix(0, J, J)
  for (i in seq_len(J)) {
    for (j in i:J) {
      Rt <- align_ortho_mats(Ut_list[[i]], Ut_list[[j]])
      X_pair <- rbind(Xlist[[i]], Xlist[[j]] %*% Rt)
      pair_vec <- kmeans(X_pair, L, nstart = 20)$cluster
      pair_lab <- split(pair_vec, ceiling(seq_along(pair_vec) / n))
      Bh1 <- nett::estim_dcsbm(A[[i]], pair_lab[[1]])$B
      Bh2 <- nett::estim_dcsbm(A[[j]], pair_lab[[2]])$B
      dist_mat[i, j] <- norm(Bh1 - Bh2)
      dist_mat[j, i] <- dist_mat[i, j]
    }
  }

  zh <- spec_clust(dist_mat, K, nstart = 25)
  nmi_clust <- nett::compute_mutual_info(zh, z_tru)

  list(
    "diff_tru" = norm(eta[[1]] - eta[[2]]),
    "diff_est" = norm(B_est[[1]] - B_est[[2]]),
    "N" = table(xi_tru[[1]]),
    "nmi_clust" = nmi_clust,
    "nmi_aligned" = nmi_aligned,
    "nmi_sc" = nmi_sc,
    "B1_est" = B_est[[1]],
    "B2_est" = B_est[[2]]
  )
}

# Example runs; each note records the outcome observed for that seed.
get_diff_B(123)   # good estimation of B and good clustering
get_diff_B(23)    # good estimation and good clustering
# Poor estimation of B, attributed to repeated cluster sizes;
# clustering of the networks is still good
get_diff_B(1234)

# Note: results would be the same when recovering Ublist

